


import os.path
from collections.abc import Iterable
import torch
import torch.utils.data as data
import numpy as np

import pytorch_kinematics as pk

def tprint(*args):
    """Print ``args`` on the current terminal line without a trailing newline.

    A leading carriage return sends the cursor back to column 0, so repeated
    calls redraw the same line (e.g. for progress counters). Arguments are
    converted with ``str`` and separated by single spaces, like ``print``.
    """
    line = " ".join(str(arg) for arg in args)
    print("\r" + line, end="")


def quat_conjugate(a):
    """Return the conjugate of quaternion(s) ``a`` stored as (x, y, z, w).

    The vector part (x, y, z) is negated and the scalar part w is kept.
    Any leading batch dimensions are flattened for the computation and the
    result is reshaped back to ``a.shape``.
    """
    flat = a.reshape(-1, 4)
    vector_part = flat[:, :3]
    scalar_part = flat[:, -1:]
    conjugated = torch.cat((vector_part.neg(), scalar_part), dim=-1)
    return conjugated.view(a.shape)


# 
def quat_mul(a, b):
    """Hamilton product ``a * b`` of quaternions stored as (x, y, z, w).

    ``a`` and ``b`` must have identical shapes with a trailing dimension of 4;
    the result has that same shape. The product is evaluated in a factored
    form that needs only eight pairwise multiplications (plus one scaling by
    0.5) instead of the sixteen of the textbook expansion.
    """
    assert a.shape == b.shape
    out_shape = a.shape
    x1, y1, z1, w1 = a.reshape(-1, 4).unbind(dim=-1)
    x2, y2, z2, w2 = b.reshape(-1, 4).unbind(dim=-1)

    # Shared partial products of the factored formula.
    t_zx = (z1 + x1) * (x2 + y2)
    t_wy_plus = (w1 - y1) * (w2 + z2)
    t_wy_minus = (w1 + y1) * (w2 - z2)
    t_sum = t_zx + t_wy_plus + t_wy_minus
    half = 0.5 * (t_sum + (z1 - x1) * (x2 - y2))

    w = half - t_zx + (z1 - y1) * (y2 - z2)
    x = half - t_sum + (x1 + w1) * (x2 + w2)
    y = half - t_wy_plus + (w1 - x1) * (y2 + z2)
    z = half - t_wy_minus + (z1 + y1) * (w2 - x2)

    return torch.stack((x, y, z, w), dim=-1).view(out_shape)

def copysign(a, b):
    # type: (float, Tensor) -> Tensor
    """Return ``|a|`` carrying the sign of each element of ``b``.

    Args:
        a: magnitude; its own sign is ignored.
        b: tensor of any shape (including 0-d) whose element signs are applied.

    Returns:
        Tensor with ``b``'s shape, on ``b``'s device. ``|a|`` is materialised
        as float32 first, so the result dtype follows torch type promotion of
        float32 with ``b``'s dtype.

    Note:
        Unlike ``math.copysign`` / ``torch.copysign``, elements where ``b`` is
        zero yield 0 (``torch.sign(0) == 0``); this is kept for compatibility
        with existing callers.
    """
    # Fill with b's full shape instead of tiling along b.shape[0]. The tiled
    # version raised on 0-d b and failed to broadcast for non-square 2-D b.
    magnitude = torch.full_like(b, abs(a), dtype=torch.float)
    return magnitude * torch.sign(b)


def get_euler_xyz(q):
    """Convert (x, y, z, w) quaternions to roll, pitch and yaw angles.

    Args:
        q: tensor whose second-axis entries 0..3 are the x, y, z, w components
            (presumably shape (N, 4)).

    Returns:
        Tuple ``(roll, pitch, yaw)``: rotations about the x, y and z axes,
        each reduced with ``% (2 * pi)``. When the pitch sine term reaches
        magnitude >= 1 (gimbal lock, or a non-unit quaternion), pitch is set
        to pi/2 with the sign of that term.
    """
    x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]

    # Rotation about x.
    roll = torch.atan2(2.0 * (w * x + y * z), w * w - x * x - y * y + z * z)

    # Rotation about y. asin is undefined outside [-1, 1], so saturate to
    # pi/2 carrying the sign of the sine term there.
    sin_pitch = 2.0 * (w * y - z * x)
    half_pi = torch.tensor(np.pi / 2.0, device=sin_pitch.device, dtype=torch.float).repeat(sin_pitch.shape[0])
    pitch = torch.where(sin_pitch.abs() >= 1, half_pi * torch.sign(sin_pitch), torch.asin(sin_pitch))

    # Rotation about z.
    yaw = torch.atan2(2.0 * (w * z + x * y), w * w + x * x - y * y - z * z)

    full_turn = 2 * np.pi
    return roll % full_turn, pitch % full_turn, yaw % full_turn


def compute_euler_diff_from_quats(quat1, quat2):
    """Euler angles of the relative rotation ``quat2 * conj(quat1)``.

    Args:
        quat1: (bsz, 4) quaternions in (x, y, z, w) order (start rotation).
        quat2: (bsz, 4) quaternions in (x, y, z, w) order (end rotation).

    Returns:
        (bsz, 3) tensor of (roll, pitch, yaw) from ``get_euler_xyz``, i.e.
        each angle reduced with ``% (2 * pi)``. The conjugate equals the
        inverse only for unit quaternions.
    """
    relative_rotation = quat_mul(quat2, quat_conjugate(quat1))
    return torch.stack(get_euler_xyz(relative_rotation), dim=-1)
    

class ControlSeqStochastic(data.Dataset):
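    """
    Dataset of hand-control training samples built from recorded experience files.

    Experience files are pickled dicts saved with ``np.save``. A sample pairs a
    short history of hand joint states/targets with a few future steps of joint
    motion, joint targets and object-rotation change; which extra fields are
    produced is controlled by ``config.invdyn``.

    Args:
        data_path: Path to an ``.npy`` index file (several may be joined with
            the literal separator "AND") mapping indices to experience-file
            paths. With ``config.invdyn.load_experience_via_mode`` it instead
            maps each mode to an "AND"-joined list of such index files.
        obs_type: Observation type; stored as ``self.obs_type``.
        history_length (int): Number of past frames in each sample.
        future_length (int): Number of future frames in each sample.
        res (int): Stored as ``self.res``.
        config: Experiment config whose ``model`` and ``invdyn`` sections are
            read; required in practice despite the ``None`` default.
        split (string): Split name; stored as ``self.split``.
    """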
    
    def __init__(self, data_path, obs_type, history_length=4, future_length=2, res=8, config=None, split='train'):
        """Index all experience files and cache the settings used to build samples.

        Args:
            data_path: path to a ``.npy`` file holding a pickled dict. Without
                ``config.invdyn.load_experience_via_mode`` it maps indices to
                experience-file paths, and several such files may be joined with
                the literal separator "AND". With that flag it maps each mode to
                an "AND"-joined list of such index files.
            obs_type: stored as ``self.obs_type``.
            history_length: number of past frames per sample.
            future_length: number of future frames per sample.
            res: stored as ``self.res``.
            config: experiment config; ``config.model``, ``config.data_type``
                and many ``config.invdyn`` fields are read here.
            split: stored as ``self.split``. Neither it nor
                ``train_split_factor`` is applied in this method.
        """
        super(ControlSeqStochastic, self).__init__()

        self.data_path = data_path
        self.obs_type = obs_type
        self.history_length = history_length
        self.future_length = future_length
        self.res = res
        self.split = split
        self.train_split_factor = 0.9
        # NOTE(review): raises AttributeError if `config` is left at its default None.
        self.model_type = config.model.subtype


        self.config = config
        self.invdyn_input_type = self.config.invdyn.obs_type
        self.diffusion_rep = self.config.model.diffusion_rep # ['link_motion', 'qpos_motion']
        self.data_type = self.config.data_type # ['tracking', 'hora']

        self.w_obj_state_history = self.config.invdyn.w_obj_state_history
        self.invdyn_train_obj_motion_pred_model = self.config.invdyn.train_obj_motion_pred_model


        self.action_type = self.config.invdyn.action_type
        self.mask_out_obj_motion = self.config.invdyn.mask_out_obj_motion

        self.use_obj_motion_norm_command = self.config.invdyn.use_obj_motion_norm_command
        # The sample loader accepts 'rot_axis', 'ori_motion', 'motion_projected_to_mode',
        # 'motion_dir' and 'motion_angvel'; any other value raises ValueError per sample.
        self.obj_motion_format = self.config.invdyn.obj_motion_format

        self.load_experience_via_mode = self.config.invdyn.load_experience_via_mode

        self.train_q_value_model = self.config.invdyn.train_q_value_model

        self.train_value_network = self.config.invdyn.train_value_network

        self.train_finger_pos_tracking_model = self.config.invdyn.train_finger_pos_tracking_model

        self.hist_context_length = self.config.invdyn.hist_context_length

        self.add_noise_onto_hist_obs = self.config.invdyn.add_noise_onto_hist_obs
        # (sic) "nosie" is the spelling of the config key itself.
        self.hist_obs_nosie_scale = self.config.invdyn.hist_obs_nosie_scale
        self.invdyn_w_hand_root_ornt = self.config.invdyn.w_hand_root_ornt


        # Index of the finger whose position is tracked (selects fingertip name and joint block).
        self.finger_pos_tracking_target_finger_idx = self.config.invdyn.finger_pos_tracking_target_finger_idx

        # Must run after finger_pos_tracking_target_finger_idx is set: _build_pk_chain reads it.
        if self.train_finger_pos_tracking_model:
            self._build_pk_chain()

        # Build a flat index experience_idx_to_data_fn[i] -> experience file.
        # Mode-based loading stores (file, mode) tuples; otherwise entries are bare paths.
        if self.load_experience_via_mode:
            mode_to_experience_file = np.load(self.data_path, allow_pickle=True).item()
            self.experience_idx_to_data_fn = {}
            self.tot_experience_idx = 0
            for cur_mode in mode_to_experience_file:
                cur_mode_experience_file = mode_to_experience_file[cur_mode]

                cur_mode_files_list = cur_mode_experience_file.split("AND")
                for cur_mode_cur_file in cur_mode_files_list:

                    cur_idx_to_data_fn = np.load(cur_mode_cur_file, allow_pickle=True).item()
                    for cur_idx in cur_idx_to_data_fn:
                        self.experience_idx_to_data_fn[self.tot_experience_idx] = (cur_idx_to_data_fn[cur_idx], cur_mode)
                        self.tot_experience_idx += 1
        else:
            tot_data_path = self.data_path.split("AND")
            self.experience_idx_to_data_fn = {}
            self.tot_experience_idx = 0

            for cur_data_path in tot_data_path:
                cur_idx_to_data_fn = np.load(cur_data_path, allow_pickle=True).item()
                for cur_idx in cur_idx_to_data_fn:
                    self.experience_idx_to_data_fn[self.tot_experience_idx] = cur_idx_to_data_fn[cur_idx]
                    self.tot_experience_idx += 1


        # Optionally append real-world experiences ("AND"-joined index files), repeated
        # so their count roughly matches the simulated entries already indexed.
        self.mixed_sim_real_experiences = self.config.invdyn.mixed_sim_real_experiences
        self.real_experiences_traj_idx_to_file_name = self.config.invdyn.real_experiences_traj_idx_to_file_name
        if self.mixed_sim_real_experiences:
            tot_real_traj_idx_to_data_fn_list = self.real_experiences_traj_idx_to_file_name.split('AND')
            real_experience_idx = 0
            real_traj_idx_to_data_fn = {}
            for cur_real_fn in tot_real_traj_idx_to_data_fn_list:
                cur_real_traj_idx_to_data_fn = np.load(cur_real_fn, allow_pickle=True).item()
                for cur_traj_idx in cur_real_traj_idx_to_data_fn:
                    real_traj_idx_to_data_fn[real_experience_idx] = cur_real_traj_idx_to_data_fn[cur_traj_idx]
                    real_experience_idx += 1
            self.real_traj_idx_to_data_fn = real_traj_idx_to_data_fn

            # self.real_traj_idx_to_data_fn = np.load(self.real_experiences_traj_idx_to_file_name, allow_pickle=True).item()
            # NOTE(review): ZeroDivisionError if the real index files are empty. If there are
            # fewer sim entries than real ones, repeat_times is 0 and no real data is added.
            repeat_times = self.tot_experience_idx // len(self.real_traj_idx_to_data_fn) # 1 : 1
            for i_repeat in range(repeat_times):
                for cur_traj_idx in self.real_traj_idx_to_data_fn:
                    cur_data_fn = self.real_traj_idx_to_data_fn[cur_traj_idx]
                    # Real entries are bare paths (no mode tuple), even in mode-based loading.
                    self.experience_idx_to_data_fn[self.tot_experience_idx] = cur_data_fn
                    self.tot_experience_idx += 1


        # self.experience_idx_to_data_fn = np.load(self.data_path, allow_pickle=True).item()

        # self.length = len(self.experience_idx_to_data_fn)
        # NOTE(review): hard-codes 400 time steps per trajectory. The sample loader wraps its
        # time index with `% (nn_ts - 2)`; confirm 400 matches the recorded data.
        self.per_experience_data_nn = 400 - 2
        self.nn_experience = len(self.experience_idx_to_data_fn)

        # Total samples = experiences x usable time steps per experience.
        self.length = self.nn_experience * self.per_experience_data_nn
        self.use_relative_target = self.config.invdyn.use_relative_target


        #### Extrin prediction setting ####
        self.pred_extrin = self.config.invdyn.pred_extrin
        # self.extrin_history_length = 50
        # Number of past frames gathered for extrinsics-history inputs.
        self.extrin_history_length = 30

        # Per-joint lower/upper limits for the 16 hand joints, passed to self._unscale.
        # Built from default numpy arrays, so these are float64 tensors.
        self.hand_dof_lower = torch.from_numpy(np.array([
            -0.3140, -1.0470, -0.5060, -0.3660, -0.3490, -0.4700, -1.2000, -1.3400,
            -0.3140, -1.0470, -0.5060, -0.3660, -0.3140, -1.0470, -0.5060, -0.3660
        ]))
        self.hand_dof_upper = torch.from_numpy(np.array([
            2.2300, 1.0470, 1.8850, 2.0420, 2.0940, 2.4430, 1.9000, 1.8800, 2.2300,
            1.0470, 1.8850, 2.0420, 2.2300, 1.0470, 1.8850, 2.0420
        ]))

        # Only referenced from commented-out code in this file.
        self.relative_qtars_scale_coef = 24

        self.normalize_input = self.config.invdyn.normalize_input
        self.normalize_output = self.config.invdyn.normalize_output

        self.obj_state_predictor = self.config.invdyn.obj_state_predictor

        # Joint-name orderings of the hand in pytorch_kinematics (pk) vs isaacgym.
        # vec_pk[idxes_from_pk_to_isaacgym] reorders a pk-ordered joint vector into isaacgym order.
        self.pk_joint_names = ['1', '0', '2', '3', '5', '4', '6', '7', '9', '8', '10', '11', '12', '13', '14', '15']
        self.isaacgym_joint_names =  ['1', '0', '2', '3', '12', '13', '14', '15', '5', '4', '6', '7', '9', '8', '10', '11']
        self.idxes_from_pk_to_isaacgym = [self.pk_joint_names.index(cur_joint_name) for cur_joint_name in self.isaacgym_joint_names]
        self.idxes_from_pk_to_isaacgym = np.array(self.idxes_from_pk_to_isaacgym).astype(np.int32)

        # Inverse permutation: vec_isaac[idxes_from_isaacgym_to_pk] gives pk order.
        self.idxes_from_isaacgym_to_pk = [self.isaacgym_joint_names.index(cur_joint_name) for cur_joint_name in self.pk_joint_names]
        self.idxes_from_isaacgym_to_pk = np.array(self.idxes_from_isaacgym_to_pk).astype(np.int32)

    def _build_pk_chain(self, urdf_path="../RL/assets/leap_hand/leap_hand_right.urdf"):
        """Build the pytorch_kinematics chain for the hand and finger index tables.

        Args:
            urdf_path: path to the LEAP hand URDF. The default is the historical
                location and is resolved relative to the current working
                directory.

        Sets:
            self.chain: kinematic chain parsed from the URDF, moved to float32 / CPU.
            self.isaac_order_to_pk_order: LongTensor permutation
                [0..3, 8..15, 4..7] (equal to ``idxes_from_isaacgym_to_pk``).
            self.fingertip_names: fingertip link names, indexed by
                ``self.finger_pos_tracking_target_finger_idx``.
            self.target_finger_joint_idxes: LongTensor of the four joint indices
                ``4*k .. 4*k+3`` with k = ``self.finger_pos_tracking_target_finger_idx``.
        """
        # Context manager closes the file promptly; a bare open(...).read()
        # leaves the handle open until garbage collection.
        with open(urdf_path) as urdf_file:
            urdf_xml = urdf_file.read()
        self.chain = pk.build_chain_from_urdf(urdf_xml).to(dtype=torch.float32, device='cpu')

        self.isaac_order_to_pk_order = torch.tensor(
            list(range(4)) + list(range(8, 16)) + list(range(4, 8)),
            dtype=torch.long, device='cpu',
        )
        self.fingertip_names = [
            'index_tip_head', 'thumb_tip_head', 'middle_tip_head', 'ring_tip_head'
        ]
        finger_idx = self.finger_pos_tracking_target_finger_idx
        self.target_finger_joint_idxes = torch.tensor(
            list(range(4 * finger_idx, (finger_idx + 1) * 4)), dtype=torch.long
        )
    
    def _forward_pk_chain_for_finger_pos(self, joint_angles):
        """Return the tracked fingertip's position for each joint configuration.

        Args:
            joint_angles: full hand joint angles, forwarded unchanged to
                ``self.chain.forward_kinematics`` (presumably shape (B, 16) in
                the chain's joint order; TODO confirm).

        Returns:
            Tensor of shape (B, 3): the translation column ``[:, :3, 3]`` of the
            fingertip link's transform matrix, where B is the batch dimension of
            ``get_matrix()``.
        """
        # NOTE(review): no joint reordering happens here. If callers pass angles in
        # isaacgym order, they may need self.isaac_order_to_pk_order first; confirm.
        link_transforms = self.chain.forward_kinematics(joint_angles)
        tip_link = self.fingertip_names[self.finger_pos_tracking_target_finger_idx]
        tip_matrix = link_transforms[tip_link].get_matrix()
        return tip_matrix[:, :3, 3]
        
    
    def _load_single_experience_freehand(self, data_fn, cur_data_mode):
        
        # loaded_data = {
        data = np.load(data_fn, allow_pickle=True).item()
        state = data['state']
        motion = data['motion']
        qtars = data['qtars']
        state = state[self.idxes_from_pk_to_isaacgym]
        motion = motion[self.idxes_from_pk_to_isaacgym]
        qtars = qtars[self.idxes_from_pk_to_isaacgym] # predict the qtars in the isaacgy'e joint order #
        
        state = torch.from_numpy(state).float()
        motion = torch.from_numpy(motion).float()
        qtars = torch.from_numpy(qtars).float()
        
        state = state.unsqueeze(0).repeat(self.history_length, 1).contiguous()
        motion = motion.unsqueeze(0).repeat(self.future_length, 1).contiguous()
        qtars = qtars.unsqueeze(0).repeat(self.future_length, 1).contiguous()
        
        # state, motion, qtars #
        future_obj_euler_diff = torch.zeros((self.future_length, 3)).float()
        
        loaded_data = {
            # state is the last state, motion is the input qpos motion -- 
            'history_hand_qpos': state.view(-1).contiguous(),
            'future_hand_qpos_motion': motion.view(-1).contiguous(),
            'future_hand_qtars': qtars.view(-1).contiguous(),
            'history_hand_qtars': state.view(-1).contiguous(),
            'future_obj_euler_diff': future_obj_euler_diff.view(-1).contiguous(),
            'mode': 'hand'
        }
        return loaded_data



    def _load_experiences_stochastic(self, data_idx, ts_idx):
        data_fn = self.experience_idx_to_data_fn[data_idx ] 
        
        if isinstance(data_fn, tuple):
            data_fn, cur_data_mode = data_fn
            if cur_data_mode == 'hand':
                loaded_data = self._load_single_experience_freehand(data_fn, cur_data_mode)
                return loaded_data
        else:
            cur_data_mode = None
        
        
        data = np.load(data_fn, allow_pickle=True).item()
        
        if 'qpos' in data:
            tot_states = data['qpos'] # nn_ts x nn_hand_dof #
            tot_qtars = data['qtars'] # nn_ts x nn_hand_dof #
            tot_obj_pose = np.array([0, 0, 0, 0, 0, 0, 1], dtype=np.float32).reshape(1, 7).repeat(tot_states.shape[0], axis=0) # nn_ts x 7 #
        else:
            tot_states = data['shadow_hand_dof_pos']
            tot_qtars = data['shadow_hand_dof_tars'] # nn_ts x nn_hand_dof #
            tot_obj_pose = data['object_pose'] # nn_ts x 7 #
            if self.w_obj_state_history:
                tot_link_pos = data['link_pos'] # nn_ts x nn_links x 3 #
                
        if self.invdyn_w_hand_root_ornt:
            hand_root_ornt = data['hand_root_ornt']
            hand_root_ornt = torch.from_numpy(hand_root_ornt[0]).float()
    
        tot_states = torch.from_numpy(tot_states).float()
        tot_qtars = torch.from_numpy(tot_qtars).float()
        tot_obj_pose = torch.from_numpy(tot_obj_pose).float()
        if self.w_obj_state_history:
            tot_link_pos = torch.from_numpy(tot_link_pos).float()
        
        
        nn_ts = tot_states.shape[0]
        rand_ts = ts_idx
        
        rand_ts = rand_ts % (tot_states.shape[0] - 2)
        ts_idx = ts_idx % (tot_states.shape[0] - 2)
        
        tot_history_ts = [ rand_ts - self.history_length + i + 1 for i in range(0, self.history_length) ]
        tot_history_ts = [ max(0, cur_ts) for cur_ts in tot_history_ts ]
        tot_history_ts = torch.tensor(tot_history_ts).long() # (nn_history_len, )
        # nn_ts x nn_envs x nn_feature_dim #
        # history_real_arm_pos = real_arm_pos[tot_history_ts]
        # history_real_leap_pos_to_sim = real_leap_pos_to_sim[tot_history_ts]
        # # history_real_object_pose = real_object_pose[tot_history_ts]
        
        # history_actions = tot_qtars[tot_history_ts]
        
        tot_future_ts = [rand_ts + i + 1 for i in range(0, self.future_length)]
        tot_future_ts = [min(nn_ts - 1, cur_ts) for cur_ts in tot_future_ts]
        tot_future_ts = torch.tensor(tot_future_ts).long()
        # future_cur_step_already_execute_actions = cur_step_already_execute_actions[tot_future_ts]
        
        # future_real_arm_pos = real_arm_pos[tot_future_ts]
        # future_real_leap_pos_to_sim = real_leap_pos_to_sim[tot_future_ts]
        # future_hand_qpos = torch.cat([ future_real_arm_pos, future_real_leap_pos_to_sim ], dim=-1)
    
        future_object_pose = tot_obj_pose[tot_future_ts]
        history_object_pose = tot_obj_pose[tot_history_ts]
        
        tot_ref_ts = [rand_ts + i for i in range(0, self.future_length)]
        tot_ref_ts = [ min(nn_ts - 1, cur_ts) for cur_ts in tot_ref_ts]
        tot_ref_ts = torch.tensor(tot_ref_ts, ).long()
        # goal_hand_qpos = tot_goal_hand_qpos[tot_ref_ts] 
        
        history_qtars = tot_qtars[tot_history_ts]
        history_qpos = tot_states[tot_history_ts]
        future_qpos = tot_states[tot_future_ts]
        expanded_future_qpos = torch.cat(
            [ history_qpos[-1:], future_qpos ], dim=0
        )
        future_qpos_motion = expanded_future_qpos[1:] - expanded_future_qpos[:-1]
        
        
        
        future_qtars = tot_qtars[tot_future_ts]
        
        if self.action_type == 'relative':
            expanded_future_qtars = torch.cat(
                [ history_qtars[-1:], future_qtars ], dim=0
            )
            future_qtars = expanded_future_qtars[1:] - expanded_future_qtars[:-1]
            
        
        expanded_future_obj_pose = torch.cat(
            [ history_object_pose[-1: ], future_object_pose ], dim=0
        )
        prev_rot = expanded_future_obj_pose[:-1, 3:]
        nex_rot = expanded_future_obj_pose[1:, 3:]
        
        prev_nex_rot_euler_diff = compute_euler_diff_from_quats(prev_rot, nex_rot)
        # prev_nex_rot_euler_diff = prev_nex_rot_euler_diff.contiguous().view(prev_rot.size(0), prev_rot.size(1), -1).contiguous()
        prev_nex_rot_euler_diff[prev_nex_rot_euler_diff >=  np.pi] -= 2 * np.pi
        prev_nex_rot_euler_diff[prev_nex_rot_euler_diff < -np.pi] += 2 * np.pi
        # prev_nex_rot_euler_diff = prev_nex_rot_euler_diff.contiguous().transpose(1, 0).contiguous()
        # prev_nex_rot_euler_diff = prev_nex_rot_euler_diff.contiguous().view(prev_nex_rot_euler_diff.size(0), -1).contiguous()


        last_hist_obj_rot = history_object_pose[-1, 3:]
        first_hist_obj_rot = history_object_pose[0, 3:]
        last_hist_obj_rot_diff_from_prev = compute_euler_diff_from_quats(first_hist_obj_rot.unsqueeze(0), last_hist_obj_rot.unsqueeze(0)).squeeze(0)
        last_hist_obj_rot_diff_from_prev[last_hist_obj_rot_diff_from_prev >= np.pi] -= 2 * np.pi
        last_hist_obj_rot_diff_from_prev[last_hist_obj_rot_diff_from_prev < -np.pi] += 2 * np.pi


        if self.w_obj_state_history:
            history_link_pos = tot_link_pos[tot_history_ts]
            history_palm_link_pos = history_link_pos[:, 0, :] 
            canon_history_object_pose = history_object_pose.clone()
            canon_history_object_pose[:, :3]   -= history_palm_link_pos
            canon_history_object_pose = canon_history_object_pose.contiguous().view(canon_history_object_pose.size(0), -1).contiguous()


        
        if self.pred_extrin:
            
            # history_qpos = self._unscale(history_qpos, self.hand_dof_lower, self.hand_dof_upper).float()
            # # if not self.notscale_targets:
            # #     history_qtars = self._unscale(history_qtars, self.hand_dof_lower, self.hand_dof_upper).float()
            # if self.use_relative_target:
            #     expanded_future_qtars = torch.cat(
            #         [ history_qtars[-1:] ,future_qtars ], dim=0
            #     )
            #     relative_qtars = expanded_future_qtars[1:] - expanded_future_qtars[:-1] 
            #     # calculate the slaed relative future qtars #
            #     future_qtars = relative_qtars * float(self.relative_qtars_scale_coef)
            # else:
            #     future_qtars = self._unscale(future_qtars, self.hand_dof_lower, self.hand_dof_upper).float()
                
            
            if self.normalize_input:
                history_qpos = self._unscale(history_qpos, self.hand_dof_lower, self.hand_dof_upper).float()
            if self.normalize_output:
                future_qtars = self._unscale(future_qtars, self.hand_dof_lower, self.hand_dof_upper).float()
            # TODO: construct the extrin preidction history #
            
            
            # extrin_history_length = 30
            # tot_extrin_history_ts = [ rand_ts - extrin_history_length + i + 1 for i in range(0, extrin_history_length) ]
            # tot_extrin_history_ts = [ max(0, cur_ts) for cur_ts in tot_extrin_history_ts ]
            # tot_extrin_history_ts = torch.tensor(tot_extrin_history_ts).long() 
            
            # extrin_history_qpos = tot_states[tot_extrin_history_ts]
            # extrin_history_qtars = tot_qtars[tot_extrin_history_ts]
            # extrin_history_qpos = self._unscale(extrin_history_qpos, self.hand_dof_lower, self.hand_dof_upper).float()
            # extrin_history_qtars = self._unscale(extrin_history_qtars, self.hand_dof_lower, self.hand_dof_upper).float()
            
        else:
            
            if self.add_noise_onto_hist_obs:
                history_qpos = history_qpos + torch.randn_like(history_qpos) * self.hist_obs_nosie_scale
                history_qtars = history_qtars + torch.randn_like(history_qtars) * self.hist_obs_nosie_scale
            
            if self.normalize_input:
                history_qpos = self._unscale(history_qpos, self.hand_dof_lower, self.hand_dof_upper).float()
            if self.normalize_output:
                future_qtars = self._unscale(future_qtars, self.hand_dof_lower, self.hand_dof_upper).float()
        
        history_qpos = history_qpos.contiguous().view(-1).contiguous()
        future_qpos_motion = future_qpos_motion.contiguous().view(-1).contiguous()
        future_qtars = future_qtars.contiguous().view(-1).contiguous()
        history_qtars = history_qtars.contiguous().view(-1).contiguous()
        prev_nex_rot_euler_diff = prev_nex_rot_euler_diff.contiguous().view(-1).contiguous()
        future_qpos = future_qpos.contiguous().view(-1).contiguous()
        
        
        if self.pred_extrin:
            extrin = data['extrin'] # nn_ts x nn_extrin_emb_dim
            cur_extrin = extrin[ts_idx + 1] # nn_extrin_emb_dim
            
            hist_extrin_idxes = [ rand_ts - self.extrin_history_length + i + 1 for i in range(0, self.extrin_history_length) ]
            hist_extrin_idxes = [ max(0, cur_ts) for cur_ts in hist_extrin_idxes ]
            hist_extrin_idxes = torch.tensor(hist_extrin_idxes).long() # (nn_history_len, )
        
            hist_extrin_qpos = tot_states[hist_extrin_idxes]
            
            hist_extrin_qpos = self._unscale(hist_extrin_qpos, self.hand_dof_lower, self.hand_dof_upper).float()
            
            hist_extrin_qtars = tot_qtars[hist_extrin_idxes]
            hist_extrin = torch.cat(
                [ hist_extrin_qpos, hist_extrin_qtars ], dim=-1
            )
        
        if self.mask_out_obj_motion:
            prev_nex_rot_euler_diff = prev_nex_rot_euler_diff * 0.0
        
        if self.use_obj_motion_norm_command:
            assert not self.mask_out_obj_motion, "use_obj_motion_norm_command and mask_out_obj_motion cannot be used together"
            unexpanded_nex_rot_euler_diff = prev_nex_rot_euler_diff.contiguous().view(self.future_length, -1).contiguous()
            norm_unexpanded_nex_rot_euler_diff = torch.norm(unexpanded_nex_rot_euler_diff, p=2, dim=-1, keepdim=True)
            unexpanded_nex_rot_euler_diff[:, :] = norm_unexpanded_nex_rot_euler_diff.contiguous().repeat(1, unexpanded_nex_rot_euler_diff.size(1)).clone()
            prev_nex_rot_euler_diff = unexpanded_nex_rot_euler_diff.contiguous().view(-1).contiguous()
        
        
        
        if cur_data_mode is not None:
            ########## cur_data_mode_tensor, rot_dir, ang_vel ##########
            ### Data mode tensor --- that describes the rotation direction that we want the object to achieve during the training ###
            cur_data_mode_tensor = torch.tensor(cur_data_mode, dtype=torch.float32).unsqueeze(0).repeat(self.future_length, 1).contiguous() # (3,) -- data mode tensor 
            cur_data_mode_tensor = cur_data_mode_tensor.contiguous().view(-1).contiguous()
            
            
            ### object rotation velocity ---- computed as the velocity from existing data ###
            unexpanded_rot_diff = prev_nex_rot_euler_diff.contiguous().view(self.future_length, -1).contiguous()
            # 0.0083333 
            unexpanded_rot_dir = unexpanded_rot_diff / torch.clamp(torch.norm(unexpanded_rot_diff, p=2, dim=-1, keepdim=True), min=1e-6) # (nn_future_length x 3)
            rot_dir = unexpanded_rot_dir.contiguous().view(-1).contiguous()
            d_time = 0.0083333
            ctl_freq = 6
            unexpanded_ang_vel = unexpanded_rot_diff / (d_time * ctl_freq)
            ang_vel = unexpanded_ang_vel.contiguous().view(-1).contiguous() # # 
            ########## cur_data_mode_tensor, rot_dir, ang_vel ##########
            
            obj_tracking_ref_dict = {
                'cur_data_mode_tensor': cur_data_mode_tensor,
                'rot_dir': rot_dir,
                'ang_vel': ang_vel
            }
        else:
            obj_tracking_ref_dict = {}
        
        
        
        if 'rot_axis' in data:
            data_rot_axis = data['rot_axis']
            cur_data_mode = data_rot_axis[tot_future_ts]
            # NOTE: that's a HACK for the rotation axis command #
            cur_data_mode = cur_data_mode[0] #  * (-1) # rotation axis for the whole episode is fixed #
        # ['ori_motion', 'motion_dir', 'motion_projected_to_mode']
        
        if self.obj_motion_format == 'rot_axis':
            prev_nex_rot_euler_diff = torch.tensor(cur_data_mode, dtype=torch.float32)
            prev_nex_rot_euler_diff = prev_nex_rot_euler_diff.contiguous().unsqueeze(0).repeat(self.future_length, 1).contiguous().view(-1).contiguous()
        elif self.obj_motion_format == 'ori_motion':
            prev_nex_rot_euler_diff = prev_nex_rot_euler_diff
        elif self.obj_motion_format == 'motion_projected_to_mode':
            prev_nex_rot_euler_diff_expanded = prev_nex_rot_euler_diff.contiguous().view(self.future_length, -1).contiguous()
            cur_data_mode_th = torch.tensor(cur_data_mode, dtype=torch.float32) # (3,) -- cur_data_mode
            # nn_future_length x 3 #
            projected_euler_diff = torch.sum(cur_data_mode_th.unsqueeze(0) * prev_nex_rot_euler_diff_expanded, dim=-1, keepdim=True) * cur_data_mode_th.unsqueeze(0)
            prev_nex_rot_euler_diff = projected_euler_diff.contiguous().view(-1).contiguous()
        elif self.obj_motion_format == 'motion_dir':
            prev_nex_rot_euler_diff = torch.tensor(cur_data_mode, dtype=torch.float32).unsqueeze(0).repeat(self.future_length, 1).contiguous() # (3,) -- cur_data_mode
            prev_nex_rot_euler_diff = prev_nex_rot_euler_diff.contiguous().view(-1).contiguous()
        elif self.obj_motion_format == 'motion_angvel':
            expanded_prev_nex_rot_euler_diff = prev_nex_rot_euler_diff.contiguous().view(self.future_length, -1).contiguous()
            d_time = 0.0083333
            ctl_freq = 6
            expanded_prev_nex_rot_euler_diff = expanded_prev_nex_rot_euler_diff / (d_time * ctl_freq)
            prev_nex_rot_euler_diff = expanded_prev_nex_rot_euler_diff.contiguous().view(-1).contiguous()
        
        
        else:
            raise ValueError(f"Unrecognized obj motion format: {self.obj_motion_format}")
        
        
        loaded_data = {
            'history_hand_qpos' : history_qpos,
            'future_hand_qpos_motion' : future_qpos_motion,
            'future_hand_qtars': future_qtars,
            'history_hand_qtars': history_qtars,
            'future_hand_qpos': future_qpos,
            'future_obj_euler_diff': prev_nex_rot_euler_diff,
            'hist_obj_rot_diff': last_hist_obj_rot_diff_from_prev
            # 'history_obj_pose': canon_history_object_pose
        }
        
        
        if self.invdyn_w_hand_root_ornt:
            loaded_data.update(
                {
                    'hand_root_ornt': hand_root_ornt
                }
            )
        
        
        if self.hist_context_length > 0:
            # load the history context # # rand # if history context length is larger than 0, then we should add histcontext to the data dict #
            history_context_ts = [ rand_ts - self.hist_context_length + i + 1 for i in range(0, self.hist_context_length) ]
            history_context_ts = torch.tensor(history_context_ts).long()
            context_qpos, context_qtars = tot_states[history_context_ts], tot_qtars[history_context_ts]
            unscaled_context_qpos = self._unscale(context_qpos, self.hand_dof_lower, self.hand_dof_upper).float()
            hist_context = torch.cat(
                [ unscaled_context_qpos, context_qtars ], dim=-1
            )
            loaded_data.update(
                {
                    'hist_context': hist_context
                }
            )
        
        if self.train_q_value_model:
            tot_reward_buf = data['reward_buf']
            history_reward = tot_reward_buf[tot_history_ts] # nn_hist_length x 1
            future_reward = tot_reward_buf[tot_future_ts] # nn_future_ts x 1 
            history_reward = torch.from_numpy(history_reward).float().contiguous().view(-1).contiguous()
            future_reward = torch.from_numpy(future_reward).float().contiguous().view(-1).contiguous()
            
            nex_history_ts = [ rand_ts - self.history_length + i + 2 for i in range(0, self.history_length) ]
            nex_history_ts = [ max(0, cur_ts) for cur_ts in nex_history_ts ]
            nex_history_ts = torch.tensor(nex_history_ts).long() # (nn_history_len, )
            
            nex_hist_qpos = tot_states[nex_history_ts]
            nex_hist_qtars = tot_qtars[nex_history_ts]
            nex_hist_qpos = nex_hist_qpos.contiguous().view(-1).contiguous()
            nex_hist_qtars = nex_hist_qtars.contiguous().view(-1).contiguous()
            nex_hist_state = torch.cat(
                [ nex_hist_qpos, nex_hist_qtars ], dim=-1
            )
            
            nex_future_ts = [ rand_ts + i + 2 for i in range(0, self.future_length) ]
            nex_future_ts = [ min(nn_ts - 1, cur_ts) for cur_ts in nex_future_ts ]
            nex_future_ts = torch.tensor(nex_future_ts).long() # (nn_future_len, )
            
            nex_future_qtars = tot_qtars[nex_future_ts]
            nex_future_qtars = nex_future_qtars.contiguous().view(-1).contiguous()
            
            loaded_data.update(
                {
                    'history_reward': history_reward,
                    'future_reward': future_reward,
                    'nex_state': nex_hist_state,
                    'nex_action': nex_future_qtars
                }
            )
            
            
        ### if train value network ### --- load the values and let the network to predict the values from history ### 
        ### predict the value from what? --- should include the next action ###
        if self.train_value_network:
            # history state, hsitroy actions # with one frame of the next action #
            # rand ts - (self.hist_length) + (self.hist_length - 1) + 1 # = rand_ts
            value_net_history_ts = [ rand_ts - self.history_length + i + 1 for i in range(0, self.history_length) ]
            value_net_history_ts = torch.tensor(value_net_history_ts, dtype=torch.long)
            value_net_hist_state = tot_states[value_net_history_ts]
            value_net_hist_action = tot_qtars[value_net_history_ts]
            
            future_ts = rand_ts + 1
            nex_action = tot_qtars[future_ts]
            
            tot_values = data['value_buf']
            tot_values = torch.from_numpy(tot_values).float() 
            nex_value = tot_values[future_ts]
            
            unscaled_value_net_hist_state = self._unscale(value_net_hist_state, self.hand_dof_lower, self.hand_dof_upper).float()
            # unscaled_nex_action = self._unscale(nex_action, self.hand_dof_lower, self.hand_dof_upper).float() # actions are not unscaled heer #
            
            value_net_hist_info = torch.cat(
                [ unscaled_value_net_hist_state, value_net_hist_action ], dim=-1
            ) # hsit_length x (16 + 16)
            value_net_hist_info = value_net_hist_info.contiguous().view(-1).contiguous()
            
            # waht's the current predicted value? -- futrue value #
            
            loaded_data.update(
                {
                    'value_net_hist_info': value_net_hist_info,
                    'value_net_nex_action': nex_action       ,
                    'value_net_nex_value': nex_value 
                }
            )
            
        
        if self.w_obj_state_history:
            loaded_data.update(
                {
                    'history_obj_pose': canon_history_object_pose
                }
            )
            
        if self.train_finger_pos_tracking_model:
            # get the history hand qpos and qtars -- as the hsitroy context #
            hand_qpos_w_one_frame = torch.cat(
                [tot_states[tot_history_ts][-1: ], tot_states[tot_future_ts][:1]], dim=0
            ) 
            cur_finger_qpos_state = tot_states[tot_history_ts[-1]][self.target_finger_joint_idxes] # finger joint dimension #
            hand_target_finger_pos = self._forward_pk_chain_for_finger_pos(hand_qpos_w_one_frame) # (hist_future_ts, 3)
            cur_finger_motion = hand_target_finger_pos[1:] - hand_target_finger_pos[:-1] # (hist_future_ts - 1, 3)
            finger_pos_w_motion_ref = torch.cat(
                [ cur_finger_qpos_state, cur_finger_motion[0] ], dim=0
            )
            
            # target finger qtars #
            target_finger_qtars = tot_qtars[tot_future_ts[0]][self.target_finger_joint_idxes]
            
            loaded_data.update(
                {
                    'finger_pos_w_motion_ref': finger_pos_w_motion_ref,
                    'target_finger_qtars': target_finger_qtars
                }
            )
        
        
        if self.obj_state_predictor: #
            history_obj_pose = history_object_pose[..., 3:].contiguous().view(-1).contiguous()
            future_obj_pose = future_object_pose[..., 3:].contiguous().view(-1).contiguous()
            control_err = history_qpos.contiguous().view(self.history_length, -1).contiguous() - history_qtars.contiguous().view(self.history_length, -1).contiguous()
            control_err = control_err.contiguous().view(-1).contiguous()
            loaded_data.update(
                {
                    'history_obj_pose': history_obj_pose,
                    'future_obj_pose': future_obj_pose,
                    'control_err': control_err
                }
            )
            
        if self.pred_extrin:
            loaded_data.update(
                {
                    'gt_extrin': cur_extrin,
                    'history_extrin': hist_extrin
                }
            )
            
        loaded_data.update(
            obj_tracking_ref_dict
        )
            
        return loaded_data
                
    def _scale(self, x, lower, upper):
        return (x * (upper - lower) + upper + lower) / 2.0
        
    def _unscale(self, x, lower, upper):
        return (2.0 * x - upper - lower) / (upper - lower)
    
    def _scale_tensor_and_convert_to_numpy(self, x):
        scaled_tensor = self._scale(x.detach().cpu(), self.hand_dof_lower, self.hand_dof_upper)
        return scaled_tensor.float().numpy()
    
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target) where target is the index of the target category.
        """
        
        experience_idx = index // self.per_experience_data_nn
        ts_idx = index % self.per_experience_data_nn
        
        loaded_data = self._load_experiences_stochastic(experience_idx,  ts_idx)
        
        
        if self.invdyn_input_type == 'link_pos':
            cur_history_link_state = loaded_data['history_link_state'][index] # nn_history_length x nn_link x 3
            cur_future_link_motion = self.preload_experiences['future_link_motion'][index] # nn_future_length x nn_link x 3
            cur_history_link_state = cur_history_link_state.contiguous().view(-1).contiguous()
            cur_future_link_motion = cur_future_link_motion.contiguous().view(-1).contiguous()
            cur_future_hand_qtars = self.preload_experiences['future_hand_qtars'][index]
            rt_dict = {
                'state': cur_history_link_state,
                'motion': cur_future_link_motion,
                'action': cur_future_hand_qtars
            }
        elif self.invdyn_input_type == 'hand_qpos':
            cur_history_hand_qpos = self.preload_experiences['history_hand_qpos'][index]
            cur_future_hand_qpos_motion = self.preload_experiences['future_hand_qpos_motion'][index]
            cur_future_hand_qtars = self.preload_experiences['future_hand_qtars'][index]
            rt_dict = {
                'state': cur_history_hand_qpos,
                'motion': cur_future_hand_qpos_motion,
                'action': cur_future_hand_qtars
            }
        elif self.invdyn_input_type == 'hand_qpos_qtars': 
            cur_history_hand_qpos = loaded_data['history_hand_qpos']
            cur_future_hand_qpos_motion = loaded_data['future_hand_qpos_motion']
            cur_history_hand_qtars = loaded_data['history_hand_qtars']
            cur_future_hand_qtars = loaded_data['future_hand_qtars']
            cur_future_obj_euler_diff = loaded_data['future_obj_euler_diff']
            
            if 'mode' in loaded_data and loaded_data['mode'] == 'hand':
                hand_cond_mask = torch.tensor([True], dtype=torch.bool)
                obj_cond_mask = torch.tensor([False], dtype=torch.bool)
            else:
                hand_cond_mask = torch.tensor([False], dtype=torch.bool)
                obj_cond_mask = torch.tensor([True], dtype=torch.bool)
            
            
            
            unflatten_future_obj_euler_diff = cur_future_obj_euler_diff.contiguous().view(-1, 3).contiguous()
            future_obj_rot_dir = unflatten_future_obj_euler_diff / torch.clamp(torch.norm(unflatten_future_obj_euler_diff, p=2, dim=-1, keepdim=True), min=1e-6)
            flatten_future_obj_rot_dir = future_obj_rot_dir.contiguous().view(-1).contiguous()
            
            cur_state = torch.cat(
                [ cur_history_hand_qpos, cur_history_hand_qtars ], dim=-1
            )
            rt_dict = {
                'state': cur_state,
                'motion': cur_future_hand_qpos_motion,
                'action': cur_future_hand_qtars,
                'future_hand_qpos': loaded_data['future_hand_qpos'],
                'obj_euler_diff': cur_future_obj_euler_diff,
                'obj_rot_dir': flatten_future_obj_rot_dir,
                'hand_cond_mask': hand_cond_mask,
                'obj_cond_mask': obj_cond_mask,
                'hist_obj_rot_diff': loaded_data['hist_obj_rot_diff']
            }
            
            if self.invdyn_w_hand_root_ornt:
                rt_dict.update(
                    {
                        'hand_root_ornt': loaded_data['hand_root_ornt']
                    }
                )
            
            if self.hist_context_length > 0:
                rt_dict.update(
                    {
                        'hist_context': loaded_data['hist_context']
                    }
                )
            
            
            if self.train_finger_pos_tracking_model:
                rt_dict.update(
                    { 'finger_pos_w_motion_ref': loaded_data['finger_pos_w_motion_ref'],
                        'target_finger_qtars': loaded_data['target_finger_qtars']
                     }
                )
            
            if self.train_q_value_model:
                cur_history_reward = loaded_data['history_reward']
                cur_future_reward = loaded_data['future_reward']
                cur_nex_state = loaded_data['nex_state']
                cur_nex_action = loaded_data['nex_action']
                rt_dict.update(
                    {
                        'history_reward': cur_history_reward,
                        'future_reward': cur_future_reward,
                        'nex_state': cur_nex_state,
                        'nex_action': cur_nex_action
                    }
                )
            
            if self.train_value_network:
                rt_dict.update(
                    {
                        'value_net_hist_info': loaded_data['value_net_hist_info'],
                        'value_net_nex_action': loaded_data['value_net_nex_action'],
                        'value_net_nex_value': loaded_data['value_net_nex_value']
                    }
                )
            
            if self.invdyn_train_obj_motion_pred_model:
                rt_dict.update(
                    {
                        'action': loaded_data['hist_obj_rot_diff']
                    }
                )
            
            if self.obj_state_predictor:
                control_err = loaded_data['control_err']
                history_obj_pose = loaded_data['history_obj_pose']
                future_obj_pose = loaded_data['future_obj_pose']
                state = torch.cat(
                    [cur_history_hand_qpos, cur_history_hand_qtars, control_err, history_obj_pose], dim=-1
                )
                rt_dict = {
                    'state': state,
                    'action': future_obj_pose, 
                    'obj_euler_diff': state,
                }

            if self.pred_extrin:
                rt_dict.update(
                    {
                        'gt_extrin': loaded_data['gt_extrin'],
                        'history_extrin': loaded_data['history_extrin']
                    }
                )
            
            if 'cur_data_mode_tensor' in loaded_data:
                rt_dict.update(
                    {
                        'cur_data_mode_tensor': loaded_data['cur_data_mode_tensor'],
                        'rot_dir': loaded_data['rot_dir'],
                        'ang_vel': loaded_data['ang_vel'],
                    }
                )
                
        elif self.invdyn_input_type == 'hand_qpos_qtars_obj_pose':
            cur_history_hand_qpos = loaded_data['history_hand_qpos']
            cur_future_hand_qpos_motion = loaded_data['future_hand_qpos_motion']
            cur_history_hand_qtars = loaded_data['history_hand_qtars']
            cur_future_hand_qtars = loaded_data['future_hand_qtars']
            cur_future_obj_euler_diff = loaded_data['future_obj_euler_diff']
            cur_history_obj_pose = loaded_data['history_obj_pose']
            
            unflatten_future_obj_euler_diff = cur_future_obj_euler_diff.contiguous().view(-1, 3).contiguous()
            future_obj_rot_dir = unflatten_future_obj_euler_diff / torch.clamp(torch.norm(unflatten_future_obj_euler_diff, p=2, dim=-1, keepdim=True), min=1e-6)
            flatten_future_obj_rot_dir = future_obj_rot_dir.contiguous().view(-1).contiguous()
            
            # chagne the invdyn input type --- need to change some model input dimension settings as well  #
            cur_state = torch.cat( # they are tensors that have been flatten
                [ cur_history_hand_qpos, cur_history_hand_qtars, cur_history_obj_pose ], dim=-1
            )
            rt_dict = {
                'state': cur_state, 
                'motion': cur_future_hand_qpos_motion,
                'action': cur_future_hand_qtars,
                'obj_euler_diff': cur_future_obj_euler_diff,
                'obj_rot_dir': flatten_future_obj_rot_dir
            }
        
        return rt_dict

    def __len__(self):
        return self.length


class ControlSeqWorldModel(data.Dataset):
    """
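    Dataset of hand control sequences for training a control-sequence world model.

    Samples are drawn from recorded trajectories; the dataset length is the
    number of selected trajectories times a fixed per-trajectory sample count.

    Args:
        data_path (string): Path to a ``.npy`` index file holding a pickled dict
            whose values are trajectory data files (optionally ``(file, mode)``
            tuples); several index files can be joined with ``"AND"``.
        obs_type (string): Observation type, stored as ``obs_type``.
        history_length (int, optional): Stored as ``history_length``.
        future_length (int, optional): Stored as ``future_length``.
        res (int, optional): Stored as ``res``.
        config (required): Configuration object; ``config.model`` and
            ``config.invdyn`` supply most settings.
        split (string, optional): ``'train'`` keeps the first 99% of each index
            file's entries; any other value keeps the remainder.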
    """
    
    def __init__(self, data_path, obs_type, history_length=4, future_length=2, res=8, config=None, split='train'):
        super(ControlSeqWorldModel, self).__init__()
        
        # 
        self.data_path = data_path
        self.obs_type = obs_type
        self.history_length = history_length    
        self.future_length = future_length
        self.res = res 
        self.split = split
        self.train_split_factor = 0.9
        self.model_type = config.model.subtype
        
        # control seq world model #
        self.config = config
        self.invdyn_input_type = self.config.invdyn.obs_type
        self.diffusion_rep = self.config.model.diffusion_rep # ['link_motion', 'qpos_motion']
        self.data_type = self.config.data_type # ['tracking', 'hora']
        self.w_obj_state_history = self.config.invdyn.w_obj_state_history
        
        
        self.action_type = self.config.invdyn.action_type
        self.mask_out_obj_motion = self.config.invdyn.mask_out_obj_motion

        self.use_obj_motion_norm_command = self.config.invdyn.use_obj_motion_norm_command
        # in ['ori_motion', 'motion_projected_to_mode', 'motion_dir']
        self.obj_motion_format = self.config.invdyn.obj_motion_format 

        self.load_experience_via_mode = self.config.invdyn.load_experience_via_mode
        
        self.finger_idx = self.config.invdyn.finger_idx
        self.joint_idx = self.config.invdyn.joint_idx
        
        self.wm_history_length = self.config.invdyn.wm_history_length
        self.hist_context_length = self.config.invdyn.hist_context_length
        self.hist_context_finger_idx = self.config.invdyn.hist_context_finger_idx
        
        self.wm_pred_joint_idx = self.config.invdyn.wm_pred_joint_idx
        self.add_nearing_neighbour = self.config.invdyn.add_nearing_neighbour
        self.add_nearing_finger = self.config.invdyn.add_nearing_finger
        
        self.wm_as_invdyn_prediction = self.config.invdyn.wm_as_invdyn_prediction
        
        self.stack_wm_history = self.config.invdyn.stack_wm_history
        self.multi_joint_single_wm = self.config.invdyn.multi_joint_single_wm # 
        self.multi_finger_single_wm = self.config.invdyn.multi_finger_single_wm
        self.single_hand_wm = self.config.invdyn.single_hand_wm
        
        self.fullhand_wobjstate_wm = self.config.invdyn.fullhand_wobjstate_wm
        
        tot_data_path = self.data_path.split("AND")
        self.experience_idx_to_data_fn = {}
        self.tot_experience_idx = 0
        
        train_ratio = 0.9
        train_ratio = 0.99
        
        for cur_data_path in tot_data_path:
            cur_idx_to_data_fn = np.load(cur_data_path, allow_pickle=True).item()
            
            tot_traj_keys = list(cur_idx_to_data_fn.keys())
            if split == 'train':
                train_keys = tot_traj_keys[ : int(len(tot_traj_keys) * train_ratio)]
            else:
                train_keys = tot_traj_keys[ int(len(tot_traj_keys) * train_ratio) :  ]
            
            for cur_idx in train_keys:
                self.experience_idx_to_data_fn[self.tot_experience_idx] = cur_idx_to_data_fn[cur_idx]
                self.tot_experience_idx += 1

        
        self.per_experience_data_nn = 400 - 4 
        self.nn_experience = len(self.experience_idx_to_data_fn)
        
        
        self.length = self.nn_experience * self.per_experience_data_nn
        
        
        self.use_relative_target = self.config.invdyn.use_relative_target
        self.add_obs_noise_scale = self.config.invdyn.add_obs_noise_scale
        self.add_action_noise_scale = self.config.invdyn.add_action_noise_scale
        
        #### Extrin prediction setting #### ## pred ## 
        self.pred_extrin = self.config.invdyn.pred_extrin
        # self.extrin_history_length = 50
        self.extrin_history_length = 30
        
        self.hand_dof_lower = torch.from_numpy(np.array([
            -0.3140, -1.0470, -0.5060, -0.3660, -0.3490, -0.4700, -1.2000, -1.3400,
            -0.3140, -1.0470, -0.5060, -0.3660, -0.3140, -1.0470, -0.5060, -0.3660
        ]))
        self.hand_dof_upper = torch.from_numpy(np.array([
            2.2300, 1.0470, 1.8850, 2.0420, 2.0940, 2.4430, 1.9000, 1.8800, 2.2300,
            1.0470, 1.8850, 2.0420, 2.2300, 1.0470, 1.8850, 2.0420
        ]))
    
        self.relative_qtars_scale_coef = 24
        

        self.normalize_input = self.config.invdyn.normalize_input
        self.normalize_output = self.config.invdyn.normalize_output
        self.obj_state_predictor = self.config.invdyn.obj_state_predictor
    

        self.pk_joint_names = ['1', '0', '2', '3', '5', '4', '6', '7', '9', '8', '10', '11', '12', '13', '14', '15']
        self.isaacgym_joint_names =  ['1', '0', '2', '3', '12', '13', '14', '15', '5', '4', '6', '7', '9', '8', '10', '11']
        self.idxes_from_pk_to_isaacgym = [self.pk_joint_names.index(cur_joint_name) for cur_joint_name in self.isaacgym_joint_names]
        self.idxes_from_pk_to_isaacgym = np.array(self.idxes_from_pk_to_isaacgym).astype(np.int32)
    
    
        if self.joint_idx >= 0:
            finger_st_idx, finger_ed_idx = self.joint_idx, self.joint_idx + 1
            
            if self.add_nearing_neighbour:
                if self.joint_idx % 4 == 0:
                    self.bf_joint_idx = self.joint_idx
                else:
                    self.bf_joint_idx = self.joint_idx - 1
                if (self.joint_idx + 1) % 4 == 0:
                    self.af_joint_idx = self.joint_idx
                else:
                    self.af_joint_idx = self.joint_idx + 1
            
        elif self.finger_idx >= 0:
            finger_st_idx, finger_ed_idx = self.finger_idx * 4, (self.finger_idx + 1) * 4
        else:
            finger_st_idx, finger_ed_idx = 0, 16
        
        self.target_finger_joint_idxes = [ _ for _ in range(finger_st_idx, finger_ed_idx) ]
        self.target_finger_joint_idxes = torch.tensor(self.target_finger_joint_idxes, dtype=torch.long)
        # nearing finger joint idxes #
        
        self.nex_joint_idxes = [1, 2, 3, 2, 5, 6, 7, 6, 9, 10, 11, 10, 12, 13, 14, 13]
        self.nex_joint_idxes = [1, 0, 1, 2, 
                                5, 4, 5, 6, 
                                9, 8, 9, 10, 
                                13, 12, 13, 14]
        self.nex_joint_idxes = torch.tensor(self.nex_joint_idxes, dtype=torch.long)
        self.pred_nearing_joint = self.config.invdyn.pred_nearing_joint
    
    
    def _load_experiences_stochastic(self, data_idx, ts_idx):
        data_fn = self.experience_idx_to_data_fn[data_idx ]
        
        if isinstance(data_fn, tuple):
            data_fn, cur_data_mode = data_fn
        else:
            cur_data_mode = None
        
        
        data = np.load(data_fn, allow_pickle=True).item()
        
        if 'qpos' in data:
            tot_states = data['qpos'] # nn_ts x nn_hand_dof #
            tot_qtars = data['qtars'] # nn_ts x nn_hand_dof #
            # print(f"tot_states: {tot_states.shape}, tot_qtars: {tot_qtars.shape}")
            if len(tot_states.shape) == 3:
                tot_states = tot_states[:, 0]
                tot_qtars = tot_qtars[:, 0]
            if 'object_pose' in data:
                tot_obj_pose = data['object_pose']
            else:
                tot_obj_pose = np.array([0, 0, 0, 0, 0, 0, 1], dtype=np.float32).reshape(1, 7).repeat(tot_states.shape[0], axis=0) # nn_ts x 7 #
        else:
            tot_states = data['shadow_hand_dof_pos']
            tot_qtars = data['shadow_hand_dof_tars'] # nn_ts x nn_hand_dof #
            tot_obj_pose = data['object_pose'] # nn_ts x 7 #
            if self.w_obj_state_history:
                tot_link_pos = data['link_pos'] # nn_ts x nn_links x 3 #
        
        tot_states = torch.from_numpy(tot_states).float()
        tot_qtars = torch.from_numpy(tot_qtars).float()
        tot_obj_pose = torch.from_numpy(tot_obj_pose).float()
        if self.w_obj_state_history:
            tot_link_pos = torch.from_numpy(tot_link_pos).float()
        
        
        nn_ts = tot_states.shape[0]
        rand_ts = ts_idx
        
        # cur state -> tot_states, tot_qtars, tot_obj_pose, tot_link_pos #
        
        if self.joint_idx >= 0:
            finger_st_idx, finger_ed_idx = self.joint_idx, self.joint_idx + 1
        elif self.finger_idx >= 0:
            finger_st_idx, finger_ed_idx = self.finger_idx * 4, (self.finger_idx + 1) * 4
        else:
            finger_st_idx, finger_ed_idx = 0, tot_states.size(-1)
        
        self.target_finger_joint_idxes = [ _ for _ in range(finger_st_idx, finger_ed_idx) ]
        self.target_finger_joint_idxes = torch.tensor(self.target_finger_joint_idxes, dtype=torch.long)
        
        if self.wm_history_length == 1:
            cur_state = tot_states[rand_ts + 1] # (16,)
            cur_action = tot_qtars[rand_ts + 2] # (16, )
            nex_state = tot_states[rand_ts + 2] # (16,)
            
            cur_hist_action  = tot_qtars[rand_ts + 1]
            
            
            if (self.add_obs_noise_scale > 0 or self.add_action_noise_scale > 0):
                cur_state_noise = torch.randn_like(cur_state) * self.add_obs_noise_scale
                cur_action_noise = torch.randn_like(cur_action) * self.add_action_noise_scale
                cur_state = cur_state + cur_state_noise
                cur_action = cur_action + cur_action_noise
            
            
            unscaled_state = self._unscale(cur_state, self.hand_dof_lower, self.hand_dof_upper).float()
            unscaled_action = self._unscale(cur_action, self.hand_dof_lower, self.hand_dof_upper).float()
            unscaled_nex_state = self._unscale(nex_state, self.hand_dof_lower, self.hand_dof_upper).float()
            
            if self.wm_pred_joint_idx >= 0:
                unscaled_nex_state = unscaled_nex_state[self.wm_pred_joint_idx: self.wm_pred_joint_idx + 1]
            elif self.joint_idx >= 0:
                finger_st_idx, finger_ed_idx = self.joint_idx, self.joint_idx + 1
                unscaled_nex_state = unscaled_nex_state[finger_st_idx: finger_ed_idx]
            elif self.finger_idx >= 0:
                finger_st_idx, finger_ed_idx = self.finger_idx * 4, (self.finger_idx + 1) * 4
                unscaled_nex_state = unscaled_nex_state[finger_st_idx: finger_ed_idx]
            
            if self.joint_idx >= 0:
                finger_st_idx, finger_ed_idx = self.joint_idx, self.joint_idx + 1
                unscaled_state = unscaled_state[finger_st_idx: finger_ed_idx]
                unscaled_action = unscaled_action[finger_st_idx: finger_ed_idx]
                # unscaled_nex_state = unscaled_nex_state[finger_st_idx: finger_ed_idx]
                
                cur_hist_action = cur_hist_action[finger_st_idx: finger_ed_idx]
                
                scaled_action = cur_action[finger_st_idx: finger_ed_idx]
                
            elif self.finger_idx >= 0:
                finger_st_idx, finger_ed_idx = self.finger_idx * 4, (self.finger_idx + 1) * 4
                unscaled_state = unscaled_state[finger_st_idx: finger_ed_idx]
                unscaled_action = unscaled_action[finger_st_idx: finger_ed_idx]
                # unscaled_nex_state = unscaled_nex_state[finger_st_idx: finger_ed_idx]
                
                cur_hist_action = cur_hist_action[finger_st_idx: finger_ed_idx]
                
                scaled_action = cur_action[finger_st_idx: finger_ed_idx]
                
        elif self.wm_history_length > 1:
            histor_ts = [ _ for _ in range( rand_ts + 1 - self.wm_history_length + 1 , rand_ts + 2 ) ]
            # ..., rand_ts + 1
            histor_ts = [ max(0, cur_ts) for cur_ts in histor_ts ]
            histor_ts = torch.tensor(histor_ts, dtype=torch.long) # (nn_history_length, )
            cur_state = tot_states[histor_ts] # nn_history_length x nn_hand_dof
            action_history_ts = [ _ for _ in range( rand_ts + 2 - self.wm_history_length + 1 , rand_ts + 3 ) ]
            action_history_ts = [ max(0, cur_ts) for cur_ts in action_history_ts ]
            action_history_ts = torch.tensor(action_history_ts, dtype=torch.long) # (nn_history_length, )
            cur_action = tot_qtars[action_history_ts] # nn_history_length x nn_hand_dof
            nex_state = tot_states[rand_ts + 2] # (16, ) 
            
            if self.pred_nearing_joint:
                # nex_state = nex_state[self.nex_joint_idxes]
                nex_state = tot_qtars[action_history_ts[-1]][self.nex_joint_idxes]
            
            cur_hist_action = tot_qtars[action_history_ts]
            
            if (self.add_obs_noise_scale > 0 or self.add_action_noise_scale > 0):
                cur_state_noise = torch.randn_like(cur_state) * self.add_obs_noise_scale
                cur_action_noise = torch.randn_like(cur_action) * self.add_action_noise_scale
                cur_state = cur_state + cur_state_noise
                cur_action = cur_action + cur_action_noise
            
            if self.stack_wm_history:
                last_rnd_ts = rand_ts - self.wm_history_length + 1; 
                last_rnd_ts_to_rnd_ts_list = [_ for _ in range(last_rnd_ts, rand_ts + 1)]
                # last_rnd_ts = max(0, last_rnd_ts)
                last_rnd_ts_to_rnd_ts_list = [ max(0, cur_ts) for cur_ts in last_rnd_ts_to_rnd_ts_list ]
                tot_cur_state = []
                tot_cur_action = []
                for cur_rnd_ts in last_rnd_ts_to_rnd_ts_list:
                    cur_history_ts = [ _ for _ in range(cur_rnd_ts + 1 - self.wm_history_length + 1 , cur_rnd_ts + 2 ) ]
                    cur_history_ts = [ max(0, cur_ts) for cur_ts in cur_history_ts ]
                    cur_history_ts = torch.tensor(cur_history_ts, dtype=torch.long) # (nn_history_length, )
                    cur_rnd_ts_state = tot_states[cur_history_ts]
                    cur_action_history_ts = [ _ for _ in range(cur_rnd_ts + 2 - self.wm_history_length  + 1, cur_rnd_ts + 3) ]
                    cur_action_history_ts = [ max(0, cur_ts) for cur_ts in cur_action_history_ts ]
                    cur_action_history_ts = torch.tensor(cur_action_history_ts, dtype=torch.long)
                    cur_rnd_ts_action = tot_qtars[cur_action_history_ts]
                    tot_cur_state.append(cur_rnd_ts_state)
                    tot_cur_action.append(cur_rnd_ts_action)
                tot_cur_state = torch.stack(tot_cur_state, dim=0)
                tot_cur_action = torch.stack(tot_cur_action, dim=0)
                cur_state = tot_cur_state
                cur_action = tot_cur_action
                
                
            
            unscaled_state = self._unscale(cur_state, self.hand_dof_lower, self.hand_dof_upper).float()
            unscaled_action = self._unscale(cur_action, self.hand_dof_lower, self.hand_dof_upper).float()
            unscaled_nex_state = self._unscale(nex_state, self.hand_dof_lower, self.hand_dof_upper).float()
            
            
            if self.wm_as_invdyn_prediction: # swap nex_state and nex_action # # nex action #
                unscaled_nex_action = unscaled_action[-1].clone()
                unscaled_action[-1] = unscaled_nex_state.clone()
                unscaled_action[:-1] = unscaled_action[:-1] * 0.0
                unscaled_nex_state = unscaled_nex_action.clone() # 
                
            
            if self.wm_pred_joint_idx >= 0:
                unscaled_nex_state = unscaled_nex_state[self.wm_pred_joint_idx: self.wm_pred_joint_idx + 1]
            elif self.joint_idx >= 0:
                finger_st_idx, finger_ed_idx = self.joint_idx, self.joint_idx + 1
                unscaled_nex_state = unscaled_nex_state[finger_st_idx: finger_ed_idx]
            elif self.finger_idx >= 0:
                finger_st_idx, finger_ed_idx = self.finger_idx * 4, (self.finger_idx + 1) * 4
                unscaled_nex_state = unscaled_nex_state[finger_st_idx: finger_ed_idx]
            
            if self.joint_idx >= 0:
                if self.add_nearing_finger:
                    finger_idx = self.joint_idx // 4
                    unscaled_state = unscaled_state[:, finger_idx * 4: (finger_idx + 1) * 4]
                    unscaled_action = unscaled_action[:, finger_idx * 4: (finger_idx + 1) * 4]
                    cur_hist_action = cur_hist_action[:, finger_idx * 4: (finger_idx + 1) * 4]
                    scaled_action = cur_action[:, finger_idx * 4: (finger_idx + 1) * 4]
                else:
                    if self.add_nearing_neighbour:
                        bf_joint_cur_state = unscaled_state[-1:, self.bf_joint_idx: self.bf_joint_idx + 1]
                        bf_joint_cur_action = unscaled_action[-1, self.bf_joint_idx: self.bf_joint_idx + 1]
                        af_joint_cur_state = unscaled_state[-1, self.af_joint_idx: self.af_joint_idx + 1]
                        af_joint_cur_action = unscaled_action[-1, self.af_joint_idx: self.af_joint_idx + 1]
                    
                    finger_st_idx, finger_ed_idx = self.joint_idx, self.joint_idx + 1
                    unscaled_state = unscaled_state[:, finger_st_idx: finger_ed_idx] # nn_hist_length x nn_hand_dof
                    unscaled_action = unscaled_action[:, finger_st_idx: finger_ed_idx]
                    # unscaled_nex_state = unscaled_nex_state[finger_st_idx: finger_ed_idx]
                    
                    cur_hist_action = cur_hist_action[:, finger_st_idx: finger_ed_idx]
                    
                    scaled_action = cur_action[:, finger_st_idx: finger_ed_idx]
                
            elif self.finger_idx >= 0:
                finger_st_idx, finger_ed_idx = self.finger_idx * 4, (self.finger_idx + 1) * 4
                unscaled_state = unscaled_state[:, finger_st_idx: finger_ed_idx]
                unscaled_action = unscaled_action[:, finger_st_idx: finger_ed_idx]
                # unscaled_nex_state = unscaled_nex_state[finger_st_idx: finger_ed_idx]
                cur_hist_action = cur_hist_action[:, finger_st_idx: finger_ed_idx]
                scaled_action  = cur_action[:, finger_st_idx: finger_ed_idx]
            else:
                scaled_action = cur_action[:]
                
            if self.stack_wm_history:
                unscaled_state = unscaled_state.contiguous().view(unscaled_state.size(0), -1).contiguous()
                unscaled_action = unscaled_action.contiguous().view(unscaled_action.size(0), -1).contiguous()
                
                if self.fullhand_wobjstate_wm:
                    nex_obj_pose_state = tot_obj_pose[rand_ts + 2][3:]
                    unscaled_nex_state = torch.cat([unscaled_nex_state, nex_obj_pose_state], dim=-1)
                    cur_obj_pose_state = tot_obj_pose[rand_ts + 1][3:]
                
            else:
                if self.multi_joint_single_wm or self.multi_finger_single_wm or self.single_hand_wm or self.fullhand_wobjstate_wm:
                    if self.fullhand_wobjstate_wm:
                        cur_obj_pose_state = tot_obj_pose[rand_ts + 1][3:]
                        unscaled_state = torch.cat([ unscaled_state.contiguous().view(-1).contiguous(), cur_obj_pose_state ], dim=-1)
                        unscaled_action = unscaled_action
                        
                        nex_obj_pose_state = tot_obj_pose[rand_ts + 2][3:]
                        unscaled_nex_state = torch.cat([unscaled_nex_state, nex_obj_pose_state], dim=-1)
                        
                    else:
                        unscaled_state = unscaled_state
                        unscaled_action = unscaled_action
                else:
                    unscaled_state = unscaled_state.contiguous().view(-1).contiguous()
                    unscaled_action = unscaled_action.contiguous().view(-1).contiguous()
            
            cur_hist_action = cur_hist_action.contiguous().view(-1).contiguous()
            scaled_action = scaled_action.contiguous().view(-1).contiguous()
            
            if self.add_nearing_neighbour:
                # bf_joint_cur_state, bf_joint_cur_action, af_joint_cur_state, af_joint_cur_action
                bf_joint_cur_state = bf_joint_cur_state.contiguous().view(-1).contiguous()
                bf_joint_cur_action = bf_joint_cur_action.contiguous().view(-1).contiguous()
                af_joint_cur_state = af_joint_cur_state.contiguous().view(-1).contiguous()
                af_joint_cur_action = af_joint_cur_action.contiguous().view(-1).contiguous()
                
                unscaled_state = torch.cat([unscaled_state, bf_joint_cur_state, af_joint_cur_state])
                unscaled_action = torch.cat([unscaled_action, bf_joint_cur_action, af_joint_cur_action])
                
                # print(f"unscaled_state: {unscaled_state.size()}, unscaled_action: {unscaled_action.size()}, bf_joint_cur_state: {bf_joint_cur_state.size()}, bf_joint_cur_action: {bf_joint_cur_action.size()}, af_joint_cur_action: {af_joint_cur_action.size()}")
        
        
        loaded_data = {
            'state': unscaled_state,
            'action': unscaled_action,
            'nex_state': unscaled_nex_state,
            'scaled_action': scaled_action,
            'hist_action': cur_hist_action
        }
        
        if self.stack_wm_history and self.fullhand_wobjstate_wm:
            loaded_data['cur_obj_pose_state'] = cur_obj_pose_state
        
        if self.hist_context_length > 0:
            hist_ts_arr = [ _ for _ in range( rand_ts + 2 - self.hist_context_length , rand_ts + 2 ) ]
            hist_ts_arr = [ max(0, cur_ts) for cur_ts in hist_ts_arr ]
            hist_ts_arr = torch.tensor(hist_ts_arr, dtype=torch.long) # (nn_history_length, )
            
            if self.hist_context_finger_idx >= 0:
                hist_context_finger_jt_st_idx, hist_context_finger_jt_ed_idx = self.hist_context_finger_idx * 4, (self.hist_context_finger_idx + 1) * 4
                hist_context_finger_jt_idxes = [ _ for _ in range(hist_context_finger_jt_st_idx, hist_context_finger_jt_ed_idx) ]
                hist_context_finger_jt_idxes = torch.tensor(hist_context_finger_jt_idxes, dtype=torch.long) # (nn_history_length, 
            else:
                hist_context_finger_jt_idxes = [ _ for _ in range(0, tot_states.size(-1)) ]
                hist_context_finger_jt_idxes = torch.tensor(hist_context_finger_jt_idxes, dtype=torch.long) # (nn_history_length, )
            
            # hist_state = tot_states[hist_ts_arr] # nn_history_length x nn_hand_dof
            # hist_action = tot_qtars[hist_ts_arr] # nn_history_length x nn_hand_dof
            
            hist_state = tot_states[hist_ts_arr] # nn_history_length x nn_hand_dof
            hist_action = tot_qtars[hist_ts_arr] # nn_history_length x nn_hand_dof
            
            hist_state = hist_state[..., hist_context_finger_jt_idxes]
            hist_action = hist_action[..., hist_context_finger_jt_idxes]
            
            hist_state = hist_state.contiguous().view(-1).contiguous()
            hist_action = hist_action.contiguous().view(-1).contiguous()
            
            if self.add_obs_noise_scale > 0 or self.add_action_noise_scale > 0:
                hist_state_noise = torch.randn_like(hist_state) * self.add_obs_noise_scale
                hist_action_noise = torch.randn_like(hist_action) * self.add_action_noise_scale
                hist_state = hist_state + hist_state_noise
                hist_action = hist_action + hist_action_noise
            
            loaded_data['hist_state'] = hist_state
            loaded_data['hist_action'] = hist_action
        
        return loaded_data
    
    def unscale_states(self, scaled_states):
        """Normalize ``scaled_states`` with ``self._unscale`` using this dataset's DOF limits.

        The DOF window is chosen the same way as elsewhere in the dataset:
        a single joint when ``self.joint_idx >= 0``, otherwise a block of 4
        consecutive DOFs when ``self.finger_idx >= 0``, otherwise every DOF in
        ``self.hand_dof_lower``. The matching slices of ``hand_dof_lower`` /
        ``hand_dof_upper`` are moved to ``scaled_states.device`` before use.

        NOTE(review): despite the parameter name, ``_unscale`` maps values from
        [lower, upper] to [-1, 1]; confirm which space callers actually pass.

        Args:
            scaled_states: tensor whose last dimension matches the selected DOF
                window (assumed; required for broadcasting against the limits).

        Returns:
            The normalized tensor, cast to float32 via ``.float()``.
        """
        if self.joint_idx >= 0:
            dof_window = slice(self.joint_idx, self.joint_idx + 1)
        elif self.finger_idx >= 0:
            # Each finger is taken to occupy 4 consecutive DOFs.
            dof_window = slice(self.finger_idx * 4, (self.finger_idx + 1) * 4)
        else:
            dof_window = slice(0, self.hand_dof_lower.shape[0])

        target_device = scaled_states.device
        lower_limits = self.hand_dof_lower[dof_window].to(target_device)
        upper_limits = self.hand_dof_upper[dof_window].to(target_device)
        normalized = self._unscale(scaled_states, lower_limits, upper_limits)
        return normalized.float()
        
    
    def _scale(self, x, lower, upper):
        return (x * (upper - lower) + upper + lower) / 2.0
    
    def _unscale(self, x, lower, upper):
        return (2.0 * x - upper - lower) / (upper - lower)
    
    def _scale_tensor_and_convert_to_numpy(self, x):
        scaled_tensor = self._scale(x.detach().cpu(), self.hand_dof_lower, self.hand_dof_upper)
        return scaled_tensor.float().numpy()
    
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target) where target is the index of the target category.
        """
        # load experiences 
        # 
        experience_idx = index // self.per_experience_data_nn
        ts_idx = index % self.per_experience_data_nn
        
        loaded_data = self._load_experiences_stochastic(experience_idx,  ts_idx)
        
        
        return loaded_data

    def __len__(self):
        return self.length

